# Rekall Memory Forensics
# Copyright 2016 Google Inc. All Rights Reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or (at
# your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#

from builtins import range
from builtins import object
__author__ = "Michael Cohen <scudette@gmail.com>"
import yara

from rekall import plugin
from rekall import scan

from rekall.plugins import yarascanner
from rekall.plugins.common import pfn
from rekall.plugins.tools import yara_support
from rekall.plugins.windows import common
from rekall.plugins.windows import pagefile

from rekall_lib import utils


class WinYaraScan(yarascanner.YaraScanMixin, common.WinScanner):
    """Scan using yara signatures.

    Combines the generic YaraScanMixin with the Windows scanner base class;
    the only customisation made here is enabling scan_physical by default.
    """

    # Scanner options switched on by default for this plugin.
    scanner_defaults = {"scan_physical": True}


class ContextBuffer(object):
    """A class to manage hits and create contiguous context buffers.

    Every hit is attributed to the contexts returned by get_contexts() for
    the physical page it was found in. For each context only the first hit of
    every yara string is remembered; get_combined_context_buffers() later
    concatenates those hits into one pseudo buffer per context.
    """

    def __init__(self, session):
        # Maps a page frame number to the list of owners found for it.
        self._context_cache = utils.FastStore(max_size=10000)

        # Contexts of the most recently seen page, so that consecutive hits
        # in the same page do not trigger another owner lookup.
        self.last_pfn_id = -1
        self.last_context_list = None

        # context -> {string_name: (original_offset, base64 encoded value)}
        self.hits_by_context = {}
        self.session = session
        self.address_space = session.physical_address_space

    @staticmethod
    def _encode_value(value):
        """Return the base64 encoded (bytes) form of a hit value.

        str.encode("base64") only exists on Python 2, so the codecs module is
        used instead; it yields the same encoding on Python 2 and 3.
        """
        import codecs

        if not isinstance(value, bytes):
            # Text can not be fed to the base64 codec directly.
            value = value.encode("utf-8")

        return codecs.encode(value, "base64")

    @staticmethod
    def _decode_value(encoded_value):
        """Reverse _encode_value(), returning the raw hit bytes."""
        import codecs

        return codecs.decode(encoded_value, "base64")

    def _add_hit_offset(self, context_list, string_name, original_offset,
                        value):
        """Record a hit under every context in context_list.

        Only the first hit for a given string_name is kept in each context;
        later hits of the same string are ignored.
        """
        encoded_value = None
        for context in context_list:
            hits_by_context_dict = self.hits_by_context.setdefault(context, {})
            if string_name in hits_by_context_dict:
                continue

            # Encode lazily, and at most once per call.
            if encoded_value is None:
                encoded_value = self._encode_value(value)

            hits_by_context_dict[string_name] = (original_offset,
                                                 encoded_value)

    def add_hit(self, string_name, hit_offset, value):
        """Attribute a hit at physical offset hit_offset to its contexts.

        Args:
          string_name: The name of the yara string which matched.
          hit_offset: The physical address of the match.
          value: The matched data.
        """
        pfn_id = hit_offset >> 12
        if pfn_id == self.last_pfn_id:
            # Same page as the previous hit - reuse its contexts.
            if self.last_context_list is not None:
                self._add_hit_offset(
                    self.last_context_list, string_name, hit_offset, value)
            return

        self.last_pfn_id = pfn_id
        self.last_context_list = self.get_contexts(pfn_id << 12)
        if self.last_context_list:
            self._add_hit_offset(
                self.last_context_list, string_name, hit_offset, value)
        else:
            self.session.logging.debug(
                "No process context for hit at %#x", hit_offset)

    def get_combined_context_buffers(self):
        """Yields pseudo_data for each context containing all hits.

        Yields:
          Tuples of (context, omap, pseudo_data) where pseudo_data is the
          bytes concatenation of all the hits recorded for the context (each
          followed by padding) and omap maps the offset of each hit inside
          pseudo_data back to the original offset of that hit.
        """
        # Bytes (not text) so it can be joined with the decoded hit data on
        # Python 3 as well as Python 2.
        pad = b"\xFF" * 10
        for context, hits_dict in self.hits_by_context.items():
            data = []
            data_len = 0
            # Map the original offset to the dummy buffer offset.
            omap = {}
            for hit_offset, encoded_value in hits_dict.values():
                omap[data_len] = hit_offset

                value = self._decode_value(encoded_value)
                # Some padding separates out the sigs.
                data.append(value)
                data.append(pad)

                data_len += len(value) + len(pad)

            yield context, omap, b"".join(data)

    def process_owners_from_physical_address(self, address):
        """Get the process owner from the physical address.

        We could use the ptov() or rammap() plugin but this is a very fast
        implementation which only cares about the identity of the owner.

        Args:
          address: A physical address.

        Returns:
          A list of owners for the page containing address. Results are
          cached per page frame number.
        """
        pfn_id = address >> 12
        try:
            return self._context_cache.Get(pfn_id)
        except KeyError:
            pass

        # Try to find a process that owns this page. This is an optimized
        # version of the algorithm in the pfn, ptov and rammap plugins.
        pfn_database = self.session.profile.get_constant_object(
            "MmPfnDatabase")
        pfn_obj = pfn_database[pfn_id]
        # This is a mapped file.
        if pfn_obj.IsPrototype:
            # This is the controlling PTE.
            pte_address = pfn_obj.PteAddress.v()
            # NOTE(review): these entries are keyed by pte_address >> 12 in
            # the same cache which is otherwise keyed by physical page frame
            # numbers - confirm the two key spaces can not collide.
            try:
                # All PTEs in that page are owned by the same owners.
                return self._context_cache.Get(pte_address >> 12)
            except KeyError:
                descriptor = pagefile.WindowsFileMappingDescriptor(
                    session=self.session, pte_address=pte_address)

                owners = [x[0] for x in descriptor.get_owners()]
                self._context_cache.Put(pte_address >> 12, owners)
                return owners

        # We only care about the process owner so this is the first half of
        # pfn.ptov._ptov_x64_hardware_PTE()
        p_addr = address
        pfns = []

        for _ in range(4):
            pfn_id = p_addr >> 12
            try:
                owners = self._context_cache.Get(pfn_id)
            except KeyError:
                pass
            else:
                # Everything visited on the way here has the same owners.
                for visited_pfn_id in pfns:
                    self._context_cache.Put(visited_pfn_id, owners)

                return owners

            pfn_obj = pfn_database[pfn_id]
            pfns.append(pfn_id)

            # The PTE which controls this pfn.
            pte = pfn_obj.PteAddress

            # The physical address of the PTE.
            p_addr = ((pfn_obj.u4.PteFrame << 12) |
                      (pte.v() & 0xFFF))

        # The DTB must be page aligned.
        descriptor = pagefile.WindowsDTBDescriptor(
            session=self.session, dtb=p_addr & ~0xFFF)

        owners = [descriptor.owner()]
        for visited_pfn_id in pfns:
            self._context_cache.Put(visited_pfn_id, owners)

        return owners

    def get_contexts(self, offset):
        """Get some context about this offset.

        We use this context to group similar yara hits into logical groups.

        Returns:
          a list of things which can be used as contexts - i.e. they are unique
          for all pages common within this context. Pages will be grouped by
          these contexts and evaluated together.
        """
        owners = self.process_owners_from_physical_address(offset)
        if not owners:
            return []

        return [x.obj_offset for x in owners]


class WinPhysicalYaraScanner(common.AbstractWindowsCommandPlugin):
    """An experimental yara scanner over the physical address space.

    Yara does not provide a streaming interface, which means that when we scan
    for yara rules we can only ever match strings within the same buffer. This
    is a problem for physical address space scanning because each page (although
    it might appear to be contiguous) usually comes from a different
    process/mapped file.

    Therefore we need a more intelligent way to apply yara signatures on the
    physical address space:

    1. The original set of yara rules is converted into a single rule with all
    the strings from all the rules in it. The rule has a condition "any of them"
    which will match any string appearing in the scanned buffer.

    2. This rule is then applied over the physical address space.

    3. For each hit we derive a context and add the hit to the context.

    4. Finally we test all the rules within the same context with the original
    rule set.
    """

    name = "yarascan_physical"

    table_header = [
        dict(name="Owner", width=20),
        dict(name="Rule", width=10),
        dict(name="Offset", style="address"),
        dict(name="HexDump", hex_width=16, width=67),
        dict(name="Context"),
    ]

    # NOTE(review): "hits" is declared here but is not read anywhere in this
    # class - confirm whether collect() is expected to honour it.
    __args = [
        dict(name="hits", default=10, type="IntParser",
             help="Quit after finding this many hits."),

        dict(name="yara_expression",
             help="If provided we scan for this yara "
             "expression specified in the yara DSL."),

        dict(name="yara_ast",
             help="If provided we scan for this yara "
             "expression specified in the yara JSON AST."),

        dict(name="start", default=0, type="IntParser",
             help="Start searching from this offset."),

        dict(name="limit", default=2**64, type="IntParser",
             help="The length of data to search."),

        dict(name="context", default=0x40, type="IntParser",
             help="Context to print after the hit."),

        dict(name="pre_context", default=0, type="IntParser",
             help="Context to print before the hit."),
    ]

    scanner_defaults = dict(
        scan_physical=True
    )

    def __init__(self, *args, **kwargs):
        """Compile the user's rules and the derived unified rule.

        Raises:
          plugin.PluginError: If neither yara_expression nor yara_ast was
            provided, or if the rules could not be parsed or compiled.
        """
        super(WinPhysicalYaraScanner, self).__init__(*args, **kwargs)

        # Check this before entering the try block below, otherwise the error
        # is re-wrapped as a misleading "Failed to compile ..." message.
        if not (self.plugin_args.yara_expression or
                self.plugin_args.yara_ast):
            raise plugin.PluginError("A yara expression must be provided.")

        try:
            # The user gave a yara DSL rule.
            if self.plugin_args.yara_expression:
                self.rules = yara.compile(
                    source=self.plugin_args.yara_expression)

                self.parsed_rules = yara_support.parse_yara_to_ast(
                    self.plugin_args.yara_expression)

            # User gave a yara AST.
            else:
                self.parsed_rules = self.plugin_args.yara_ast
                self.rules = yara.compile(
                    source=yara_support.ast_to_yara(self.parsed_rules))

            # Gather the strings of every rule under names which are unique
            # across the whole rule set.
            all_strings = []
            rule_id = 0
            for parsed_rule in self.parsed_rules:
                name = parsed_rule["name"]
                for k, v in parsed_rule["strings"]:
                    rule_name = "%s_%d_REKALL_%s" % (k, rule_id, name)
                    all_strings.append((rule_name, v))
                    rule_id += 1

            # A single rule which fires when any of the strings is seen.
            self.parsed_unified_rule = [
                dict(name="XX",
                     strings=all_strings,
                     condition="any of them")
            ]
            self.plugin_args.unified_yara_expression = (
                yara_support.ast_to_yara(self.parsed_unified_rule))

            self.unified_rule = yara.compile(
                source=self.plugin_args.unified_yara_expression)

            self.context_buffer = ContextBuffer(self.session)

        except Exception as e:
            raise plugin.PluginError(
                "Failed to compile yara expression: %s" % e)

    def collect(self):
        """Scan physical memory and report rule matches per context.

        First the unified rule is run over the physical address space and
        every string hit is handed to the context buffer. Then the original
        rules are run over the pseudo buffer of each context.

        Yields:
          A dict for each reported hit with the Owner, Rule, Offset, HexDump
          and Context columns declared in table_header.
        """
        address_space = self.session.physical_address_space
        for buffer_as in scan.BufferASGenerator(
                self.session, address_space,
                self.plugin_args.start,
                self.plugin_args.start + self.plugin_args.limit):
            self.session.report_progress(
                "Scanning buffer %#x->%#x (%#x)",
                buffer_as.base_offset, buffer_as.end(),
                buffer_as.end() - buffer_as.base_offset)
            for match in self.unified_rule.match(data=buffer_as.data):
                for buffer_offset, string_name, value in sorted(match.strings):
                    hit_offset = buffer_offset + buffer_as.base_offset
                    self.context_buffer.add_hit(string_name, hit_offset, value)

        # Now re-run the original expression on all unique contexts.
        it = self.context_buffer.get_combined_context_buffers()
        for context, original_offset_map, pseudo_data in it:
            seen = set()
            self.session.report_progress(
                "Scanning pseudo buffer of length %d", len(pseudo_data))
            # Report any hits of the original sig on this context.
            for match in self.rules.match(data=pseudo_data):
                self.session.report_progress()
                # Only report a single hit of the same rule on the same context.
                dedup_key = (match.rule, context)
                if dedup_key in seen:
                    continue

                seen.add(dedup_key)
                for buffer_offset, _, value in match.strings:
                    # Translate the pseudo buffer offset back to the offset
                    # the hit was originally found at.
                    hit_offset = original_offset_map.get(buffer_offset)
                    if hit_offset is None:
                        continue

                    if isinstance(context, int):
                        owner = self.session.profile._EPROCESS(context)
                    else:
                        owner = context

                    yield dict(
                        Owner=owner,
                        Rule=match.rule,
                        Offset=hit_offset,
                        HexDump=utils.HexDumpedString(
                            address_space.read(
                                hit_offset - self.plugin_args.pre_context,
                                self.plugin_args.context +
                                self.plugin_args.pre_context)),
                        Context=pfn.PhysicalAddressContext(
                            self.session, hit_offset)
                    )
